{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 14.7 来自Transformers的双向编码器表示（BERT）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tokens_and_segments(tokens_a, tokens_b=None):\n",
    "    \"\"\"获取输入序列的词元及其片段索引\"\"\"\n",
    "    tokens = ['<cls>'] + tokens_a + ['<sep>']\n",
    "    segments = [0] * (len(tokens_a) + 2)\n",
    "    if tokens_b is not None:\n",
    "        tokens += tokens_b + '<sep>'\n",
    "        segments += [1] * (len(tokens_b) + 2)\n",
    "    return tokens, segments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 14.7.4 输入表示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BERTEncoder(nn.Module):\n",
    "    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,\n",
    "                ffn_num_hiddens, num_heads, num_layers, dropout, \n",
    "                max_len=1000, key_size=768, query_size=768, value_size=768, \n",
    "                **kwargs):\n",
    "        super(BERTEncoder, self).__init__(**kwargs)\n",
    "        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)\n",
    "        self.segment_embedding = nn.Embedding(2, num_hiddens)\n",
    "        self.blks = nn.Sequential()\n",
    "        for i in range(num_layers):\n",
    "            self.blks.add_module(f\"{i}\", d2l.EncoderBlock(\n",
    "                key_size, query_size, value_size, num_hiddens, norm_shape,\n",
    "                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))\n",
    "        # 在BERT中，位置嵌入是可学习的\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))\n",
    "        \n",
    "    def forward(self, tokens, segments, valid_lens):\n",
    "        # 以下的X保持形状不变 (batch_size, seq_len, num_hiddens)\n",
    "        X = self.token_embedding(tokens) + self.segment_embedding(segments)\n",
    "        X = X + self.pos_embedding.data[:, :X.shape[1], :]\n",
    "        for blk in self.blks:\n",
    "            X = blk(X, valid_lens)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4\n",
    "norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2\n",
    "encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,\n",
    "                     ffn_num_hiddens, num_heads, num_layers, dropout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BERTEncoder(\n",
       "  (token_embedding): Embedding(10000, 768)\n",
       "  (segment_embedding): Embedding(2, 768)\n",
       "  (blks): Sequential(\n",
       "    (0): EncoderBlock(\n",
       "      (attention): MultiHeadAttention(\n",
       "        (attention): DotProductAttention(\n",
       "          (dropout): Dropout(p=0.2, inplace=False)\n",
       "        )\n",
       "        (W_q): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (W_k): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (W_v): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (W_o): Linear(in_features=768, out_features=768, bias=True)\n",
       "      )\n",
       "      (addnorm1): AddNorm(\n",
       "        (dropout): Dropout(p=0.2, inplace=False)\n",
       "        (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "      (ffn): PositionWiseFFN(\n",
       "        (dense1): Linear(in_features=768, out_features=1024, bias=True)\n",
       "        (relu): ReLU()\n",
       "        (dense2): Linear(in_features=1024, out_features=768, bias=True)\n",
       "      )\n",
       "      (addnorm2): AddNorm(\n",
       "        (dropout): Dropout(p=0.2, inplace=False)\n",
       "        (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (1): EncoderBlock(\n",
       "      (attention): MultiHeadAttention(\n",
       "        (attention): DotProductAttention(\n",
       "          (dropout): Dropout(p=0.2, inplace=False)\n",
       "        )\n",
       "        (W_q): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (W_k): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (W_v): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (W_o): Linear(in_features=768, out_features=768, bias=True)\n",
       "      )\n",
       "      (addnorm1): AddNorm(\n",
       "        (dropout): Dropout(p=0.2, inplace=False)\n",
       "        (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "      (ffn): PositionWiseFFN(\n",
       "        (dense1): Linear(in_features=768, out_features=1024, bias=True)\n",
       "        (relu): ReLU()\n",
       "        (dense2): Linear(in_features=1024, out_features=768, bias=True)\n",
       "      )\n",
       "      (addnorm2): AddNorm(\n",
       "        (dropout): Dropout(p=0.2, inplace=False)\n",
       "        (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[9393, 8439, 1957, 8169, 2655, 3479, 6307, 6103],\n",
       "        [3273, 8047, 5569, 6225, 5540, 2778, 2242, 3942]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens = torch.randint(0, vocab_size, (2, 8))\n",
    "segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])\n",
    "encoded_X = encoder(tokens, segments, None)\n",
    "tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 8, 768])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoded_X.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 14.7.5 预训练任务"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 遮蔽语言模型Masked Language Modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaskLM(nn.Module):\n",
    "    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):\n",
    "        super(MaskLM, self).__init__(**kwargs)\n",
    "        self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),\n",
    "                                nn.ReLU(),\n",
    "                                nn.LayerNorm(num_hiddens),\n",
    "                                nn.Linear(num_hiddens, vocab_size))\n",
    "        \n",
    "    def forward(self,X, pred_positions):\n",
    "        num_pred_positions = pred_positions.shape[1]\n",
    "        pred_positions = pred_positions.reshape(-1)\n",
    "        batch_size = X.shape[0]\n",
    "        batch_idx = torch.arange(0, batch_size)\n",
    "        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)\n",
    "        masked_X = X[batch_idx, pred_positions]\n",
    "        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))\n",
    "        mlm_Y_hat = self.mlp(masked_X)\n",
    "        return mlm_Y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 10000])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mlm = MaskLM(vocab_size, num_hiddens)\n",
    "mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])\n",
    "mlm_Y_hat = mlm(encoder_X, mlm_positions)\n",
    "mlm_Y_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])\n",
    "loss = nn.CrossEntropyLoss(reduction='none')\n",
    "mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))\n",
    "mlm_l.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 下一句预测 Next Sentence Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NextSentencePred(nn.Module):\n",
    "    def __init__(self, num_inputs, **kwargs):\n",
    "        super(NextSentencePred,self).__init__(**kwargs)\n",
    "        self.output = nn.Linear(num_inputs, 2)\n",
    "        \n",
    "    def forward(self, X):\n",
    "        # X的形状 (batch_size, num_hiddens)\n",
    "        return self.output(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoded_X = torch.flatten(encoded_X, start_dim=1)\n",
    "nsp = NextSentencePred(encoded_X.shape[-1])\n",
    "nsp_Y_hat = nsp(encoded_X)\n",
    "nsp_Y_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nsp_y = torch.tensor([0, 1])\n",
    "nsp_l = loss(nsp_Y_hat, nsp_y)\n",
    "nsp_l.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 14.7.6 把所有的东西放在一起"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BERTModel(nn.Module):\n",
    "    \"\"\"BERT模型\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, \n",
    "                num_heads, num_layers, dropout, max_len=1000, key_size=768, query_size=768,\n",
    "                value_size=768, hid_in_features=768, mlm_in_features=768, nsp_in_features=768):\n",
    "        super(BERTModel, self).__init__()\n",
    "        self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,\n",
    "                                  num_heads, num_layers, dropout, max_len=max_len, key_size=key_size, \n",
    "                                  query_size=query_size, value_size=value_size)\n",
    "        self.hidden = nn.Sequential(nn.Linear(hid_in_features,num_hiddens), nn.Tanh())\n",
    "        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)\n",
    "        self.nsp = NextSentencePred(nsp_in_features)\n",
    "        \n",
    "    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):\n",
    "        encoded_X = self.encoder(tokens, segments, valid_lens)\n",
    "        if pred_positions is not None:\n",
    "            mlm_Y_hat = self.mlm(encoded_X, pred_positions)\n",
    "        else:\n",
    "            mlm_Y_hat = None\n",
    "        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))\n",
    "        return encoded_X, mlm_Y_hat, nsp_Y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
